{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a name=\"top\"></a><img src=\"images/chisel_1024.png\" alt=\"Chisel logo\" style=\"width:480px;\" />"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Module 3.3: Higher-Order Functions\n",
    "**Prev: [Interlude: Chisel Standard Library](3.2_interlude.ipynb)**<br>\n",
    "**Next: [Functional Programming](3.4_functional_programming.ipynb)**\n",
    "\n",
    "## Motivation\n",
    "Those pesky `for` loops in the previous module are verbose and defeat the purpose of functional programming! In this module, your generators will get funct-ky.\n",
    "\n",
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "val path = System.getProperty(\"user.dir\") + \"/source/load-ivy.sc\"\n",
    "interp.load.module(ammonite.ops.Path(java.nio.file.FileSystems.getDefault().getPath(path)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import chisel3._\n",
    "import chisel3.util._\n",
    "import chisel3.tester._\n",
    "import chisel3.tester.RawTester.test"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# A Tale of Two FIRs <a name=\"compact-fir\"></a>\n",
    "From the last module, we had the convolution part of the FIR filter written like this:\n",
    "\n",
    "```scala\n",
    "val muls = Wire(Vec(length, UInt(8.W)))\n",
    "for(i <- 0 until length) {\n",
    "  if(i == 0) muls(i) := io.in * io.consts(i)\n",
    "  else       muls(i) := regs(i - 1) * io.consts(i)\n",
    "}\n",
    "\n",
    "val scan = Wire(Vec(length, UInt(8.W)))\n",
    "for(i <- 0 until length) {\n",
    "  if(i == 0) scan(i) := muls(i)\n",
    "  else scan(i) := muls(i) + scan(i - 1)\n",
    "}\n",
    "\n",
    "io.out := scan(length - 1)\n",
    "```\n",
    "\n",
    "As a recap, the idea is to multiply each element of `io.in` with the corresponding element of `io.consts`, and store it in `muls`.\n",
    "Then, the elements in `muls` are accumulated into `scan`, with `scan(0) = muls(0)`, `scan(1) = scan(0) + muls(1) = muls(0) + muls(1)`, and in general `scan(n) = scan(n-1) + muls(n) = muls(0) + ... + muls(n-1) + muls(n)`.\n",
    "The last element in `scan` (equal to the sum of all `muls`) is assigned to `io.out`.\n",
    "\n",
    "However, it's very verbose for what might be considered quite a simple operation. In fact, all that could be written in one line:\n",
    "\n",
    "```scala\n",
    "io.out := (taps zip io.consts).map { case (a, b) => a * b }.reduce(_ + _)\n",
    "```\n",
    "\n",
    "What is it doing?! Let's break it down:\n",
    "- assume `taps` is the list of all samples, with `taps(0) = io.in`, `taps(1) = regs(0)`, etc.\n",
    "- `(taps zip io.consts)` takes two lists, `taps` and `io.consts`, and combines them into one list where each element is a tuple of the elements at the inputs at the corresponding position. Concretely, its value would be `[(taps(0), io.consts(0)), (taps(1), io.consts(1)), ..., (taps(n), io.consts(n))]`. Remember that periods are optional, so this is equivalent to `(taps.zip(io.consts))`.\n",
    "- `.map { case (a, b) => a * b }` applies the anonymous function (takes a tuple of two elements returns their product) to the elements of the list, and returns the result. In this case, the result is equivalent to `muls` in the verbose example, and has the value `[taps(0) * io.consts(0), taps(1) * io.consts(1), ..., taps(n) * io.consts(n)]`. You'll revisit anonymous functions in the next module. For now, just learn this syntax.\n",
    "- Finally, `.reduce(_ + _)` also applies the function (addition of elements) to elements of the list. However, it takes two arguments: the first is the current accumulation, and the second is the list element (in the first iteration, it just adds the first two elements). These are given by the two underscores in the parentheses. The result would then be, assuming left-to-right traversal, `(((muls(0) + muls(1)) + muls(2)) + ...) + muls(n)`, with the result of deeper-nested parentheses evaluated first. This is the output of the convolution.\n",
    "\n",
    "---\n",
    "# Functions as Arguments\n",
    "Formally, functions like `map` and `reduce` are called _higher-order functions_ : they are functions that take functions as arguments.\n",
    "As it turns out (and hopefully, as you can see from the above example), these are very powerful constructs that encapsulate a general computational pattern, allowing you to concentrate on the application logic instead of flow control, and resulting in very concise code.\n",
    "\n",
    "## Different ways of specifying functions\n",
    "You may have noticed that there were two ways of specifying functions in the examples above:\n",
    "- For functions where each argument is referred to exactly once, you *may* be able to use an underscore (`_`) to refer to each argument. In the example above, the `reduce` argument function took two arguments and could be specified as `_ + _`. While convenient, this is subject to an additional set of arcane rules, so if it does't work, try:\n",
    "- Specifying the inputs argument list explicitly. The reduce could have been explicitly written as `(a, b) => a + b`, with the general form of putting the argument list in parentheses, followed by `=>`, followed by the function body referring to those arguments.\n",
    "- When tuple unpacking is needed, using a `case` statement, as in `case (a, b) => a * b`. That takes a single argument, a tuple of two elements, and unpacks it into variables `a` and `b`, which can then be used in the function body.\n",
    "\n",
    "## Practice in Scala\n",
    "In the last module, we've seen major classes in the Scala Collections API, like `List`s.\n",
    "These higher-order functions are part of these APIs - and in fact, the above example uses the `map` and `reduce` API on `List`s.\n",
    "In this section, we'll familiarize ourselves with these methods through examples and exercises.\n",
    "In these examples, we'll operate on Scala numbers (`Int`s) for the sake of simpliciy and clarify, but because Chisel operators behave similarly, the concepts should generalize.\n",
    "\n",
    "<span style=\"color:blue\">**Example: Map**</span><br>\n",
    "`List[A].map` has type signature `map[B](f: (A) ⇒ B): List[B]`. You'll learn more about types in a later module. For now, think of types A and B as `Int`s or `UInt`s, meaning they could be software or hardware types.\n",
    "\n",
    "In plain English, it takes an argument of type `(f: (A) ⇒ B)`, or a function that takes one argument of type `A` (the same type as the element of the input List) and returns a value of type `B` (which can be anything). `map` then returns a new list of type `B` (the return type of the argument function).\n",
    "\n",
    "As we've already explained the behavior of List in the FIR explanation, let's get straight into the examples and exercises:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "println(List(1, 2, 3, 4).map(x => x + 1))  // explicit argument list in function\n",
    "println(List(1, 2, 3, 4).map(_ + 1))  // equivalent to the above, but implicit arguments\n",
    "println(List(1, 2, 3, 4).map(_.toString + \"a\"))  // the output element type can be different from the input element type\n",
    "\n",
    "println(List((1, 5), (2, 6), (3, 7), (4, 8)).map { case (x, y) => x*y })  // this unpacks a tuple, note use of curly braces\n",
    "\n",
    "// Related: Scala has a syntax for constructing lists of sequential numbers\n",
    "println(0 to 10)  // to is inclusive , the end point is part of the result\n",
    "println(0 until 10)  // until is exclusive at the end, the end point is not part of the result\n",
    "\n",
    "// Those largely behave like lists, and can be useful for generating indices:\n",
    "val myList = List(\"a\", \"b\", \"c\", \"d\")\n",
    "println((0 until 4).map(myList(_)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color:red\">**Exercise: Map**</span><br><a name=\"map-exercise\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Now you try: \n",
    "// Fill in the blanks (the ???) such that this doubles the elements of the input list.\n",
    "// This should return: List(2, 4, 6, 8)\n",
    "println(List(1, 2, 3, 4).map(???))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color:blue\">**Example: zipWithIndex**</span><br>\n",
    "`List.zipWithIndex` has type signature `zipWithIndex: List[(A, Int)]`.\n",
    "\n",
    "It takes no arguments, but returns a list where each element is a tuple of the original elements, and the index (with the first one being zero).\n",
    "So `List(\"a\", \"b\", \"c\", \"d\").zipWithIndex` would return `List((\"a\", 0), (\"b\", 1), (\"c\", 2), (\"d\", 3))`\n",
    "\n",
    "This is useful when then element index is needed in some operation.\n",
    "\n",
    "Since this is a pretty straightforward, we'll just have some examples:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "println(List(1, 2, 3, 4).zipWithIndex)  // note indices start at zero\n",
    "println(List(\"a\", \"b\", \"c\", \"d\").zipWithIndex)\n",
    "println(List((\"a\", \"b\"), (\"c\", \"d\"), (\"e\", \"f\"), (\"g\", \"h\")).zipWithIndex)  // tuples nest"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color:blue\">**Example: Reduce**</span><br>\n",
    "`List[A].map` has type signature similar to `reduce(op: (A, A) ⇒ A): A`. (it's actually more lenient, `A` only has to be a supertype of the List type, but we're not going to deal with that syntax here)\n",
    "\n",
    "As it's also been explained above, here are some examples:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "println(List(1, 2, 3, 4).reduce((a, b) => a + b))  // returns the sum of all the elements\n",
    "println(List(1, 2, 3, 4).reduce(_ * _))  // returns the product of all the elements\n",
    "println(List(1, 2, 3, 4).map(_ + 1).reduce(_ + _))  // you can chain reduce onto the result of a map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Important note: reduce will fail with an empty list\n",
    "println(List[Int]().reduce(_ * _))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color:red\">**Exercise: Reduce**</span><br><a name=\"reduce-exercise\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Now you try: \n",
    "// Fill in the blanks (the ???) such that this returns the product of the double of the elements of the input list.\n",
    "// This should return: (1*2)*(2*2)*(3*2)*(4*2) = 384\n",
    "println(List(1, 2, 3, 4).map(???).reduce(???))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color:blue\">**Example: Fold**</span><br>\n",
    "`List[A].fold` is very similar to reduce, except that you can specify the initial accumulation value.\n",
    "It has type signature similar to `fold(z: A)(op: (A, A) ⇒ A): A`. (like `reduce`, the type of `A` is also more lenient)\n",
    "\n",
    "Notably, it takes two argument lists, the first (`z`) is the initial value, and the second is the accumulation function.\n",
    "Unlike `reduce`, it will not fail with an empty list, instead returning the initial value directly.\n",
    "\n",
    "Here's some examples:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "println(List(1, 2, 3, 4).fold(0)(_ + _))  // equivalent to the sum using reduce\n",
    "println(List(1, 2, 3, 4).fold(1)(_ + _))  // like above, but accumulation starts at 1\n",
    "println(List().fold(1)(_ + _))  // unlike reduce, does not fail on an empty input"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color:red\">**Exercise: Fold**</span><br><a name=\"fold-exercise\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Now you try: \n",
    "// Fill in the blanks (the ???) such that this returns the double the product of the elements of the input list.\n",
    "// This should return: 2*(1*2*3*4) = 48\n",
    "// Note: unless empty list tolerance is needed, reduce is a much better fit here.\n",
    "println(List(1, 2, 3, 4).fold(???)(???))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color:red\">**Exercise: Decoupled Arbiter**</span><br>\n",
    "Let's put everything together now into an exercise.\n",
    "\n",
    "For this example, we're going to build a Decoupled arbiter: a module that has _n_ Decoupled inputs and one Decoupled output. \n",
    "The arbiter selects the lowest channel that is valid and forwards it to the output.\n",
    "\n",
    "Some hints:\n",
    "- Architecturally:\n",
    "  - `io.out.valid` is true if any of the inputs are valid\n",
    "  - Consider having an internal wire of the selected channel\n",
    "  - Each input's `ready` is true if the output is ready, AND that channel is selected (this does combinationally couple ready and valid, but we'll ignore it for now...)\n",
    "- These constructs may help:\n",
    "  - `map`, especially for returning a Vec of sub-elements, for example `io.in.map(_.valid)` returns a list of valid signals of the input Bundles\n",
    "  - `PriorityMux(List[Bool, Bits])`, which takes in a list of valid signals and bits, returning the first element that is valid\n",
    "  - Dynamic index on a Vec, by indexing with a UInt, for example `io.in(0.U)`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyRoutingArbiter(numChannels: Int) extends Module {\n",
    "  val io = IO(new Bundle {\n",
    "    val in = Vec(numChannels, Flipped(Decoupled(UInt(8.W))))\n",
    "    val out = Decoupled(UInt(8.W))\n",
    "  } )\n",
    "\n",
    "  // YOUR CODE BELOW\n",
    "  ???\n",
    "}\n",
    "\n",
    "test(new MyRoutingArbiter(4)) { c =>\n",
    "    // verify that the computation is correct\n",
    "    // Set input defaults\n",
    "    for(i <- 0 until 4) {\n",
    "        c.io.in(i).valid.poke(false.B)\n",
    "        c.io.in(i).bits.poke(i.U)\n",
    "        c.io.out.ready.poke(true.B)\n",
    "    }\n",
    "\n",
    "    c.io.out.valid.expect(false.B)\n",
    "\n",
    "    // Check single input valid behavior with backpressure\n",
    "    for (i <- 0 until 4) {\n",
    "        c.io.in(i).valid.poke(true.B)\n",
    "        c.io.out.valid.expect(true.B)\n",
    "        c.io.out.bits.expect(i.U)\n",
    "\n",
    "        c.io.out.ready.poke(false.B)\n",
    "        c.io.in(i).ready.expect(false.B)\n",
    "\n",
    "        c.io.out.ready.poke(true.B)\n",
    "        c.io.in(i).valid.poke(false.B)\n",
    "    }\n",
    "\n",
    "    // Basic check of multiple input ready behavior with backpressure\n",
    "    c.io.in(1).valid.poke(true.B)\n",
    "    c.io.in(2).valid.poke(true.B)\n",
    "    c.io.out.bits.expect(1.U)\n",
    "    c.io.in(1).ready.expect(true.B)\n",
    "    c.io.in(0).ready.expect(false.B)\n",
    "\n",
    "    c.io.out.ready.poke(false.B)\n",
    "    c.io.in(1).ready.expect(false.B)\n",
    "}\n",
    "\n",
    "println(\"SUCCESS!!\") // Scala Code: if we get here, our tests passed!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div id=\"container\"><section id=\"accordion\"><div>\n",
    "<input type=\"checkbox\" id=\"check-1\" />\n",
    "<label for=\"check-1\"><strong>Solution</strong></label>\n",
    "<article>\n",
    "<pre style=\"background-color:#f7f7f7\">\n",
    "class MyRoutingArbiter(numChannels: Int) extends Module {\n",
    "  val io = IO(new Bundle {\n",
    "    val in = Vec(numChannels, Flipped(Decoupled(UInt(8.W))))\n",
    "    val out = Decoupled(UInt(8.W))\n",
    "  } )\n",
    "\n",
    "  // YOUR CODE BELOW\n",
    "  io.out.valid := io.in.map(\\_.valid).reduce(\\_ || \\_)\n",
    "  val channel = PriorityMux(\n",
    "    io.in.map(\\_.valid).zipWithIndex.map { case (valid, index) => (valid, index.U) }\n",
    "  )\n",
    "  io.out.bits := io.in(channel).bits\n",
    "  io.in.map(\\_.ready).zipWithIndex.foreach { case (ready, index) =>\n",
    "    ready := io.out.ready && channel === index.U\n",
    "  }\n",
    "}\n",
    "</pre></article></div></section></div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# You're done!\n",
    "\n",
    "[Return to the top.](#top)"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Scala",
   "language": "scala",
   "name": "scala"
  },
  "language_info": {
   "codemirror_mode": "text/x-scala",
   "file_extension": ".scala",
   "mimetype": "text/x-scala",
   "name": "scala",
   "nbconvert_exporter": "script",
   "version": "2.12.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
